import numpy as np
import tensorflow as tf

from niftynet.layer.rgb_histogram_equilisation import \
    RGBHistogramEquilisationLayer
from niftynet.utilities.util_import import require_module
from tests.niftynet_testcase import NiftyNetTestCase

# Fixed 10 x 10 x 3 float32 test image (rows, columns, colour channels; the
# channels are presumably in RGB order). The values look like 8-bit levels
# divided by 255 (e.g. 0.49803922 ~= 127/255) and occupy a narrow band of
# [0, 1], so the image's brightness histogram is concentrated in a few bins
# and histogram equalisation should visibly flatten it.
IMAGE_DATA = \
    np.array([[[0.49803922, 0.19215687, 0.3529412 ],
               [0.49411765, 0.16862746, 0.31764707],
               [0.5254902 , 0.21176471, 0.3647059 ],
               [0.45882353, 0.21176471, 0.38039216],
               [0.44705883, 0.19215687, 0.39607844],
               [0.44705883, 0.18431373, 0.39607844],
               [0.43137255, 0.18039216, 0.3882353 ],
               [0.42352942, 0.16470589, 0.34901962],
               [0.41960785, 0.14509805, 0.3647059 ],
               [0.46666667, 0.18039216, 0.38431373]],
              [[0.4509804 , 0.15686275, 0.34901962],
               [0.43137255, 0.16862746, 0.3529412 ],
               [0.4745098 , 0.21176471, 0.4       ],
               [0.47058824, 0.2784314 , 0.4509804 ],
               [0.46666667, 0.25490198, 0.4627451 ],
               [0.4509804 , 0.24705882, 0.41960785],
               [0.42745098, 0.17254902, 0.37254903],
               [0.4392157 , 0.16078432, 0.36078432],
               [0.48235294, 0.1882353 , 0.41960785],
               [0.53333336, 0.2901961 , 0.48235294]],
              [[0.45882353, 0.20784314, 0.38431373],
               [0.48235294, 0.21960784, 0.40784314],
               [0.4627451 , 0.24705882, 0.40392157],
               [0.45490196, 0.23529412, 0.4117647 ],
               [0.41568628, 0.14901961, 0.34901962],
               [0.4509804 , 0.16862746, 0.34901962],
               [0.45882353, 0.2       , 0.41568628],
               [0.5019608 , 0.24313726, 0.41960785],
               [0.5058824 , 0.23529412, 0.4392157 ],
               [0.53333336, 0.29411766, 0.46666667]],
              [[0.52156866, 0.21960784, 0.41568628],
               [0.47058824, 0.16862746, 0.36862746],
               [0.4392157 , 0.1882353 , 0.38039216],
               [0.4117647 , 0.18431373, 0.38039216],
               [0.40784314, 0.15686275, 0.36078432],
               [0.43137255, 0.14901961, 0.3529412 ],
               [0.5176471 , 0.2901961 , 0.47058824],
               [0.5137255 , 0.2509804 , 0.4627451 ],
               [0.45882353, 0.21176471, 0.39215687],
               [0.44313726, 0.18039216, 0.38039216]],
              [[0.47843137, 0.19215687, 0.36078432],
               [0.44313726, 0.14901961, 0.37254903],
               [0.40392157, 0.13333334, 0.32941177],
               [0.41568628, 0.12941177, 0.34901962],
               [0.43529412, 0.14509805, 0.38431373],
               [0.49411765, 0.23529412, 0.44313726],
               [0.5294118 , 0.3019608 , 0.45882353],
               [0.50980395, 0.25882354, 0.4392157 ],
               [0.43529412, 0.19607843, 0.3529412 ],
               [0.39215687, 0.13333334, 0.3254902 ]],
              [[0.44705883, 0.14117648, 0.34117648],
               [0.39607844, 0.12156863, 0.3137255 ],
               [0.4117647 , 0.14117648, 0.34509805],
               [0.44705883, 0.15686275, 0.3764706 ],
               [0.5058824 , 0.20784314, 0.40392157],
               [0.5294118 , 0.25490198, 0.42745098],
               [0.5137255 , 0.25882354, 0.42352942],
               [0.48235294, 0.20392157, 0.41568628],
               [0.39215687, 0.13725491, 0.3137255 ],
               [0.36078432, 0.11372549, 0.29411766]],
              [[0.41960785, 0.13333334, 0.3647059 ],
               [0.43529412, 0.1882353 , 0.38431373],
               [0.4509804 , 0.16862746, 0.3647059 ],
               [0.50980395, 0.23529412, 0.44705883],
               [0.56078434, 0.28235295, 0.45882353],
               [0.5372549 , 0.27450982, 0.42745098],
               [0.5176471 , 0.27450982, 0.47843137],
               [0.48235294, 0.24705882, 0.4       ],
               [0.39215687, 0.14901961, 0.3254902 ],
               [0.38431373, 0.13725491, 0.3137255 ]],
              [[0.45882353, 0.1764706 , 0.4       ],
               [0.50980395, 0.21568628, 0.42352942],
               [0.50980395, 0.20784314, 0.42352942],
               [0.56078434, 0.29803923, 0.49411765],
               [0.5294118 , 0.26666668, 0.43137255],
               [0.54509807, 0.3254902 , 0.50980395],
               [0.5254902 , 0.3137255 , 0.50980395],
               [0.42745098, 0.16078432, 0.32156864],
               [0.39607844, 0.12156863, 0.32156864],
               [0.3647059 , 0.09803922, 0.30588236]],
              [[0.50980395, 0.21568628, 0.4117647 ],
               [0.54509807, 0.27450982, 0.43137255],
               [0.5529412 , 0.23137255, 0.4392157 ],
               [0.5568628 , 0.26666668, 0.4627451 ],
               [0.5372549 , 0.2784314 , 0.47058824],
               [0.5529412 , 0.30980393, 0.5176471 ],
               [0.49411765, 0.21568628, 0.3882353 ],
               [0.42352942, 0.1764706 , 0.3529412 ],
               [0.38039216, 0.10980392, 0.3019608 ],
               [0.38039216, 0.09411765, 0.28627452]],
              [[0.49803922, 0.1882353 , 0.3764706 ],
               [0.54901963, 0.23529412, 0.41568628],
               [0.5568628 , 0.2901961 , 0.4862745 ],
               [0.5529412 , 0.28235295, 0.4509804 ],
               [0.5411765 , 0.29411766, 0.5137255 ],
               [0.5176471 , 0.24705882, 0.4509804 ],
               [0.41960785, 0.14901961, 0.32156864],
               [0.42352942, 0.15686275, 0.38431373],
               [0.38431373, 0.12156863, 0.31764707],
               [0.38039216, 0.10588235, 0.32156864]]], dtype=np.float32)


class RGBEquilisationTest(NiftyNetTestCase):
    """
    Test for RGBHistogramEquilisationLayer.

    The layer is applied to IMAGE_DATA as a bare array, inside a dict, and
    as 8-bit data inside a dict. In each case the luma histogram of the
    output must be flatter than the luma histogram of the input, i.e. the
    bin counts must have a smaller standard deviation.
    """

    def test_equilisation(self):
        cv2 = require_module('cv2', mandatory=False)
        if cv2 is None:
            # skipTest raises unittest.SkipTest, so nothing below runs.
            self.skipTest('requires cv2 module')

        def _get_histogram(img):
            """
            Return 32 bin counts of the image's luma (Y of YUV), 0-255 scale.

            :param img: (H, W, 3) RGB image. Float data is taken to be in
                [0, 1]; integer data to hold 8-bit levels in [0, 255].
            """
            img = np.asarray(img)
            if np.issubdtype(img.dtype, np.integer):
                # Map 8-bit levels onto [0, 1] so that `* 255` below
                # neither overflows nor bins uint8 and float outputs
                # differently.
                img = img.astype(np.float32) / 255.0
            else:
                # cv2.cvtColor accepts float32 but not float64.
                img = img.astype(np.float32)
            # The data is RGB-ordered, so convert with the RGB code. Slicing
            # with [::-1] would reverse the rows, not the channels.
            luma = cv2.cvtColor(img, cv2.COLOR_RGB2YUV)[..., 0] * 255
            return np.histogram(luma, bins=32, range=(0, 256))[0]

        def _std_of_counts(hist):
            """Standard deviation of histogram bin counts (flatness score)."""
            return hist.astype(np.float32).std()

        orig_shape = list(IMAGE_DATA.shape)
        # Insert two singleton axes before the channel axis,
        # (H, W, 3) -> (H, W, 1, 1, 3), which is the shape fed to the layer.
        input_shape = orig_shape[:2] + [1] * 2 + [3]
        before_std = _std_of_counts(_get_histogram(IMAGE_DATA))

        def _assert_flatter(output):
            """Assert the layer output's luma histogram is flatter."""
            after = _get_histogram(np.asarray(output).reshape(orig_shape))
            self.assertGreater(before_std, _std_of_counts(after))

        layer = RGBHistogramEquilisationLayer(image_name='image')

        # Case 1: bare array input.
        img, _ = layer(IMAGE_DATA.reshape(input_shape))
        _assert_flatter(img)

        # Case 2: dict input keyed by the layer's image_name.
        img, _ = layer({'image': IMAGE_DATA.reshape(input_shape)})
        _assert_flatter(img['image'])

        # Case 3: the same image quantised to uint8, passed in a dict.
        # NOTE(review): assumes the layer accepts uint8 input -- confirm.
        img_u8 = (255 * IMAGE_DATA).astype(np.uint8)
        img, _ = layer({'image': img_u8.reshape(input_shape)})
        _assert_flatter(img['image'])


if __name__ == "__main__":
    # Run this module's tests with TensorFlow's test runner when executed
    # directly as a script.
    tf.test.main()
